Search Options

Results per page
Sort
Preferred Languages
Advanced

Results 1 - 10 of 32 for getShape (0.28 sec)

  1. tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc

      if (auto tensor_type = dyn_cast<RankedTensorType>(input_type))
        return RankedTensorType::get(tensor_type.getShape(), elemental_type);
      if (auto tensor_type = dyn_cast<UnrankedTensorType>(input_type))
        return UnrankedTensorType::get(elemental_type);
      if (auto vector_type = dyn_cast<VectorType>(input_type))
        return VectorType::get(vector_type.getShape(), elemental_type);
    
      // If the expressed types match, just use the new elemental type.
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Wed May 08 02:10:16 UTC 2024
    - 4.3K bytes
    - Viewed (0)
  2. tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc

        auto rhs_type = mlir::dyn_cast<RankedTensorType>(input_rhs.getType());
    
        if (!lhs_type || !rhs_type) return failure();
    
        auto lhs_shape = lhs_type.getShape();
        auto rhs_shape = rhs_type.getShape();
    
        // Ensure that input ranks are at least 2.
        const int dims_a = lhs_shape.size();
        const int dims_b = rhs_shape.size();
        if (dims_a < 2 || dims_b < 2) {
          return failure();
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 3.8K bytes
    - Viewed (0)
  3. tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc

        // Get the unbroadcasted shapes in the operand order.
        std::array<llvm::ArrayRef<int64_t>, 2> operand_shapes;
        operand_shapes[i] = broadcast_arg_type.getShape();
        operand_shapes[1 - i] = argument_type.getShape();
    
        // Check that the input of the broadcast and the other operand is broadcast
        // compatible.
        llvm::SmallVector<int64_t, 4> broadcasted_shape;
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 7.9K bytes
    - Viewed (0)
  4. tensorflow/compiler/mlir/lite/utils/utils.td

    def CreateNoneValue : NativeCodeCall<
      "$_builder.create<TFL::NoValueOp>($0.getLoc(), $_builder.getUnitAttr())">;
    
    // Returns shape of a ranked tensor.
    // if called without a ranked tensor it will fail.
    def GetShape: NativeCodeCall<"GetShape($0)">;
    
    // Constraint that values in list attribute are all ones.
    def IsAllOnesConstant : Constraint<CPred<"TFL::IsAllOnesConstant($0)">>;
    
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Tue Apr 30 00:40:15 UTC 2024
    - 4.8K bytes
    - Viewed (0)
  5. tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils_test.cc

      ShapedType packed_shape_type =
          mlir::dyn_cast<ShapedType>(packed_value.getType());
      llvm::SmallVector<int64_t> packed_shape(packed_shape_type.getShape().begin(),
                                              packed_shape_type.getShape().end());
      EXPECT_THAT(packed_shape, testing::ElementsAreArray(expected_packed_shape));
      llvm::SmallVector<int8_t> packed_value_vector(
          packed_value_attr.getValues<int8_t>());
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 3.5K bytes
    - Viewed (0)
  6. tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.cc

      TorchAvgPoolData data;
    
      auto op_type = mlir::cast<RankedTensorType>(op.getOperand(0).getType());
    
      data.n = op_type.getShape()[0];
      data.c = op_type.getShape()[1];
      data.h_in = op_type.getShape()[2];
      data.w_in = op_type.getShape()[3];
    
      std::vector<int32_t> kernel_size;
      GetI32VectorFromDenseI64CompositeAttr(composite_attrs, "kernel_size",
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Tue May 28 23:16:05 UTC 2024
    - 9.2K bytes
    - Viewed (0)
  7. tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc

        // Thus, we fail to match if the consuming reshape rank is larger.
        ArrayRef<int64_t> input_shape = input_type.getShape();
        if (reshape_shape.size() > input_shape.size()) return failure();
    
        // Extend the input shape with leading 1s to match the broadcast shape.
        ArrayRef<int64_t> broadcast_shape = output_type.getShape();
        SmallVector<int64_t, 4> input_shape_extended;
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 8.1K bytes
    - Viewed (0)
  8. tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.cc

              bmm_op->getLoc(), permuation_tensor_type,
              DenseElementsAttr::get(permuation_tensor_type, permute));
    
          auto input_shape = input_type.getShape();
          llvm::SmallVector<int64_t, 4> permuted_shape(input_shape.begin(),
                                                       input_shape.end());
          // Swaps z dimension and x dimension to get permuted shape.
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 9.6K bytes
    - Viewed (0)
  9. tensorflow/compiler/mlir/tensorflow/utils/export_utils.h

    // ShapeContainerT is any type with the following methods:
    //   bool hasRank()
    //   ArrayRef<int64_t> getShape()
    // This includes mlir::TF::ShapeAttr and mlir::ShapedType.
    template <typename ShapeContainerT>
    void SetTensorShapeProto(ShapeContainerT shape, TensorShapeProto* proto) {
      if (shape.hasRank()) {
        for (int64_t dim : shape.getShape()) {
          proto->add_dim()->set_size(mlir::ShapedType::isDynamic(dim) ? -1 : dim);
        }
      } else {
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Fri Apr 26 09:37:10 UTC 2024
    - 3.9K bytes
    - Viewed (0)
  10. tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h

        return operand_ty == result_ty && identity_ty.hasStaticShape() &&
               result_ty.hasStaticShape() &&
               OpTrait::util::staticallyKnownBroadcastable(operand_ty.getShape(),
                                                           identity_ty.getShape());
      };
    
      // Check that we have a constant operand on one side (candidate for identity).
      const bool is_commutative =
    Registered: Sun Jun 16 05:45:23 UTC 2024
    - Last Modified: Thu Apr 25 16:01:03 UTC 2024
    - 5.3K bytes
    - Viewed (0)
Back to top